{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MSC 量化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib.msc.core.tools import ToolType\n",
    "# pylint: disable=import-outside-toplevel\n",
    "from tvm.contrib.msc.core.tools.quantize import QuantizeStage\n",
    "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n",
    "\n",
    "run_type = MSCFramework.MSC\n",
    "if run_type == MSCFramework.TENSORRT:\n",
    "    config = {\"plan_file\": \"msc_quantizer.json\", \"strategys\": []}\n",
    "else:\n",
    "    op_types = [\"nn.conv2d\", \"msc.conv2d_bias\", \"msc.linear\", \"msc.linear_bias\"]\n",
    "    config = {\n",
    "        \"plan_file\": \"msc_quantizer.json\",\n",
    "        \"strategys\": [\n",
    "            {\n",
    "                \"methods\": {\n",
    "                    \"input\": \"gather_maxmin\",\n",
    "                    \"output\": \"gather_maxmin\",\n",
    "                    \"weights\": \"gather_max_per_channel\",\n",
    "                },\n",
    "                \"op_types\": op_types,\n",
    "                \"stages\": [QuantizeStage.GATHER],\n",
    "            },\n",
    "            {\n",
    "                \"methods\": {\"input\": \"calibrate_maxmin\", \"output\": \"calibrate_maxmin\"},\n",
    "                \"op_types\": op_types,\n",
    "                \"stages\": [QuantizeStage.CALIBRATE],\n",
    "            },\n",
    "            {\n",
    "                \"methods\": {\n",
    "                    \"input\": \"quantize_normal\",\n",
    "                    \"weights\": \"quantize_normal\",\n",
    "                    \"output\": \"dequantize_normal\",\n",
    "                },\n",
    "                \"op_types\": op_types,\n",
    "            },\n",
    "        ],\n",
    "    }\n",
    "tools = [{\"tool_type\": ToolType.QUANTIZER, \"tool_config\": config}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import get_model_info, _test_from_torch\n",
    "\n",
    "_test_from_torch(\n",
    "    MSCFramework.TVM, tools, \n",
    "    get_model_info(MSCFramework.TVM), \n",
    "    temp_dir,\n",
    "    training=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'tool_type': 'quantizer',\n",
       "  'tool_config': {'plan_file': 'msc_quantizer.json',\n",
       "   'strategys': [{'methods': {'input': 'gather_maxmin',\n",
       "      'output': 'gather_maxmin',\n",
       "      'weights': 'gather_max_per_channel'},\n",
       "     'op_types': ['nn.conv2d',\n",
       "      'msc.conv2d_bias',\n",
       "      'msc.linear',\n",
       "      'msc.linear_bias'],\n",
       "     'stages': ['gather']},\n",
       "    {'methods': {'input': 'calibrate_maxmin', 'output': 'calibrate_maxmin'},\n",
       "     'op_types': ['nn.conv2d',\n",
       "      'msc.conv2d_bias',\n",
       "      'msc.linear',\n",
       "      'msc.linear_bias'],\n",
       "     'stages': ['calibrate']},\n",
       "    {'methods': {'input': 'quantize_normal',\n",
       "      'weights': 'quantize_normal',\n",
       "      'output': 'dequantize_normal'},\n",
       "     'op_types': ['nn.conv2d',\n",
       "      'msc.conv2d_bias',\n",
       "      'msc.linear',\n",
       "      'msc.linear_bias']}]}}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
